fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)#5689
Conversation
…il in ModelTrain (5504)
mufaddal-rohawala
left a comment
There was a problem hiding this comment.
🤖 AI Code Review
This PR fixes a bug where PipelineVariable objects (ParameterInteger, ParameterString) fail in safe_serialize due to their str() raising TypeError. The fallback approach using hasattr(data, 'expr') is fragile and duck-typing based, when a more robust solution would be to catch the PipelineVariable case explicitly. The tests are a good addition but have some issues including unused variables and redundant test logic.
| except TypeError: | ||
| # PipelineVariable.__str__ raises TypeError by design. | ||
| # If the isinstance check above didn't catch it (e.g. import | ||
| # path mismatch), fall back to returning the object directly |
There was a problem hiding this comment.
Using hasattr(data, 'expr') as a duck-typing check for PipelineVariable is fragile — any object with an expr attribute would match this condition. Since there's already a PipelineVariable isinstance check earlier in the function (lines 185-186), a more robust approach would be to repeat that isinstance check here, or better yet, import and check against the specific parameter types. This would also be more maintainable.
Alternatively, consider catching TypeError more broadly:
except TypeError:
try:
return str(data)
except TypeError:
# PipelineVariable.__str__ raises TypeError by design.
# Return the object directly so pipeline serialization can handle it.
return dataReturning data unconditionally in the inner except is safer than a duck-type check that could silently pass through unrelated objects. If the object truly can't be serialized or stringified, it's likely a PipelineVariable that should be passed through. If you want to be defensive, re-use the PipelineVariable isinstance check instead of hasattr.
sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
Show resolved
Hide resolved
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"max_depth": max_depth}, | ||
| ) | ||
| # safe_serialize should return the PipelineVariable object directly |
There was a problem hiding this comment.
Imports should be at the top of the file, not inside test functions. Move from sagemaker.train.utils import safe_serialize to the module-level imports. This applies to all three new test methods.
|
|
||
|
|
||
| class TestModelTrainerHyperparametersPipelineVariable: | ||
| """Test that PipelineVariable objects in hyperparameters survive safe_serialize.""" |
There was a problem hiding this comment.
These tests in test_model_trainer_pipeline_variable.py are essentially duplicates of the tests already added in test_utils.py. The safe_serialize unit tests in test_utils.py are sufficient for testing the serialization behavior. These tests should instead verify the integration — that ModelTrainer correctly preserves PipelineVariable objects in its hyperparameters dict after construction, rather than just re-testing safe_serialize.
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat |
There was a problem hiding this comment.
ParameterFloat and PipelineSession are imported but ParameterFloat is never used in the new tests, and PipelineSession is also unused. Remove unused imports.
mufaddal-rohawala
left a comment
There was a problem hiding this comment.
🤖 AI Code Review
This PR fixes a bug where PipelineVariable objects (ParameterInteger, ParameterString) fail in safe_serialize due to their str() raising TypeError. The fix adds a nested try/except in the fallback path and includes good test coverage. However, there are several issues: the return type annotation is problematic, the function signature change could cause downstream issues, and there's a missing from __future__ import annotations import.
|
|
||
|
|
||
| def safe_serialize(data): | ||
| def safe_serialize(data) -> "str | PipelineVariable": |
There was a problem hiding this comment.
The return type "str | PipelineVariable" is a forward reference string, but this module should use from __future__ import annotations at the top to enable PEP 604 union syntax properly. Also, this is a behavioral change to the function's contract — previously it always returned str, now it can return PipelineVariable. This could cause TypeError in downstream callers that expect a str (e.g., calling .encode(), concatenation, etc.). Have you audited all call sites of safe_serialize to ensure they handle a PipelineVariable return value correctly?
Suggestion: Add from __future__ import annotations at the module top, and change the annotation to:
def safe_serialize(data) -> str | PipelineVariable:Also consider using Union[str, PipelineVariable] with an explicit import if from __future__ import annotations is not already used in this module.
| return json.dumps(data) | ||
| except TypeError: | ||
| return str(data) | ||
| try: |
There was a problem hiding this comment.
The nested try/except is a reasonable defensive measure, but catching a bare TypeError from str(data) and silently returning the raw object is risky. This means any object whose __str__ raises TypeError (not just PipelineVariable) will be returned as-is, potentially causing unexpected behavior downstream. Consider being more explicit:
except TypeError:
try:
return str(data)
except TypeError:
if isinstance(data, PipelineVariable):
return data
raiseThis way, only known PipelineVariable objects get the pass-through treatment, and truly broken objects still raise.
|
|
||
|
|
||
| def safe_serialize(data): | ||
| def safe_serialize(data) -> "str | PipelineVariable": |
There was a problem hiding this comment.
Missing type annotation for the data parameter. Per SDK conventions, all public functions must have type annotations for parameters and return types.
def safe_serialize(data: Any) -> str | PipelineVariable:| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Unit tests for sagemaker.train.utils – specifically safe_serialize.""" | ||
| from __future__ import absolute_import |
There was a problem hiding this comment.
Per PEP 484 / SDK conventions, new modules should use from __future__ import annotations instead of from __future__ import absolute_import (which is a Python 2 artifact and unnecessary in Python 3).
|
|
||
|
|
||
| def test_safe_serialize_with_empty_string(): | ||
| assert safe_serialize("") == "" |
There was a problem hiding this comment.
Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where isinstance check for PipelineVariable might fail (e.g., import path mismatch). You could mock isinstance or create a mock object whose __str__ raises TypeError to verify the nested except path:
def test_safe_serialize_with_object_whose_str_raises_typeerror():
"""Objects whose __str__ raises TypeError should be returned as-is."""
class BadStr:
def __str__(self):
raise TypeError("cannot convert")
obj = BadStr()
result = safe_serialize(obj)
assert result is objThis directly tests the new fallback code path added in this PR.
|
|
||
| def test_hyperparameters_preserves_pipeline_variable_string(self): | ||
| """ParameterString in hyperparameters should be preserved in ModelTrainer.""" | ||
| optimizer = ParameterString(name="Optimizer", default_value="sgd") |
There was a problem hiding this comment.
These tests verify that ModelTrainer preserves PipelineVariable objects in the hyperparameters dict, but they don't test the actual serialization path (i.e., when safe_serialize is called during job creation). Consider adding a test that mocks the training job creation to verify that PipelineVariable values survive the full serialization pipeline, not just assignment.
Description
The
safe_serializefunction insagemaker-train/src/sagemaker/train/utils.pyalready has aPipelineVariableisinstance check (lines 185-186), but the user's SDK version (3.3.1) may not have this fix. Additionally, the fallbackexcept TypeError: return str(data)block is dangerous becausePipelineVariable.__str__()intentionally raisesTypeError, meaning if the isinstance check ever fails (e.g., import path mismatch, reload issues), the except block will re-raise. The fix needs to: (1) ensure the PipelineVariable isinstance check is solid, (2) make the except fallback more robust by catching the case wherestr()also raises TypeError, and (3) add missing unit tests forsafe_serializefrom the train utils module covering PipelineVariable inputs. There are no existing tests for thesafe_serializeinsagemaker-train/src/sagemaker/train/utils.py.Related Issue
Related issue: 5504
Changes Made
sagemaker-train/src/sagemaker/train/utils.pysagemaker-train/tests/unit/train/test_utils.pysagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.pyAI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: descriptionformat